# mypy: allow-untyped-defs
import dataclasses
import operator
import sys
from collections.abc import Callable
from enum import Enum
from typing import Optional

import torch

from .. import cpp_builder, ir
from ..cpu_vec_isa import (
    pick_vec_isa,
    VecAMX,
    VecAVX2,
    VecAVX512,
    VecISA,
    VecNEON,
    VecSVE256,
)
from ..utils import IndentedBuffer, parallel_num_threads
from ..virtualized import V
from .common import KernelTemplate
from .cpp_template_kernel import CppTemplateKernel
from .cpp_utils import DTYPE_TO_CPP, GemmBlocking, value_to_cpp


class LayoutType(Enum):
    """
    Memory layout in which a micro GEMM expects its B operand.

    Subclasses report their choice via `CppMicroGemm.get_b_layout`; the base
    class uses `NORMAL`.
    """

    # Plain, unpacked layout.
    NORMAL = 0
    # Presumably VNNI packing with 2 K-elements interleaved (vnni_size == 2
    # for non-int8 inputs) — TODO confirm against the packing code.
    VNNI2 = 1
    # Presumably VNNI packing with 4 K-elements interleaved (vnni_size == 4
    # for int8 inputs) — TODO confirm against the packing code.
    VNNI4 = 2


_IS_WINDOWS = sys.platform == "win32"


def get_restrict_keyword() -> str:
    """
    Return the C++ spelling of the `restrict` pointer qualifier for this host.

    On Windows the MSVC extension `__restrict` is used
    (https://learn.microsoft.com/en-us/cpp/cpp/extension-restrict?view=msvc-170);
    on every other platform the `__restrict__` spelling is emitted.
    """
    return "__restrict" if _IS_WINDOWS else "__restrict__"


class CppMicroGemm:
    """
    Base code generator for a small-matrix-multiplication ("micro GEMM") kernel.

    A concrete micro GEMM owns register blocking, instruction selection and the
    other CPU-architecture-specific tuning. Subclasses override
    `codegen_define` to emit the C++ definition of the kernel; `codegen_call`
    emits the call site that invokes it.
    """

    # TODO(jgong5): support constant shapes and lds as template args.
    # Shared Jinja template for the kernel signature. In the default layout
    # A is M x K (row stride lda), B is K x N (row stride ldb) and C is M x N
    # (row stride ldc).
    DECLARE_KERNEL = r"""
template <bool accum, bool prefetch=false>
inline void {{kernel_name}}(
{%- if kernel_extra_args_declare %}
    {{kernel_extra_args_declare}}
{%- endif %}
    const {{input_t}}* {{restrict_keyword}} A,
    const {{input2_t}}* {{restrict_keyword}} B,
    {{output_t}}* {{restrict_keyword}} C,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
)
"""

    def __init__(
        self,
        name,
        input_dtype,
        input2_dtype,
        output_dtype,
        compute_dtype,
        register_blocking,
        alpha=1,
    ) -> None:
        # The B operand must always carry an explicit dtype.
        assert input2_dtype is not None
        self.name = name
        self.input_dtype = input_dtype
        self.input2_dtype = input2_dtype
        self.output_dtype = output_dtype
        self.compute_dtype = compute_dtype
        self.register_blocking = register_blocking
        self.alpha = alpha
        # Toggled by `use_local_vnni_blocking`; exposed to the templates.
        self.pack_vnni_B_locally = False

    def get_common_options(self):
        """Return the Jinja rendering context shared by all templates of this micro GEMM."""
        int8_inputs = self.input_dtype in [torch.uint8, torch.int8]
        if int8_inputs:
            # Integer GEMMs require an int8 B and int32 compute/output.
            assert self.compute_dtype == torch.int32
            assert self.output_dtype == torch.int32
            assert self.input2_dtype == torch.int8
        return dict(
            torch=torch,
            kernel_name=self.name,
            input_dtype=self.input_dtype,
            input2_dtype=self.input2_dtype,
            output_dtype=self.output_dtype,
            compute_dtype=self.compute_dtype,
            input_t=DTYPE_TO_CPP[self.input_dtype],
            input2_t=DTYPE_TO_CPP[self.input2_dtype],
            output_t=DTYPE_TO_CPP[self.output_dtype],
            compute_t=DTYPE_TO_CPP[self.compute_dtype],
            alpha=self.alpha,
            kernel_extra_args_declare=self.get_kernel_extra_args_declare(),
            int8_gemm=int8_inputs,
            vnni_size=4 if int8_inputs else 2,
            restrict_keyword=get_restrict_keyword(),
            pack_vnni_B_locally=self.pack_vnni_B_locally,
            template=self,
            is_woq_int4=self.is_woq_int4(),
        )

    def get_kernel_declaration(self):
        """Render `DECLARE_KERNEL` (the kernel signature) with the common options."""
        options = self.get_common_options()
        declaration_template = KernelTemplate._template_from_string(self.DECLARE_KERNEL)
        return declaration_template.render(options)

    def get_kernel_extra_args_declare(self) -> str:
        """Extra parameter declarations placed before A/B/C; none by default."""
        return ""

    def get_kernel_extra_args(self, **kwargs) -> list[str]:
        """Extra call-site arguments matching `get_kernel_extra_args_declare`; none by default."""
        return []

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        """Emit the C++ kernel definition. Must be provided by subclasses."""
        raise NotImplementedError

    def codegen_call(
        self,
        kernel: CppTemplateKernel,
        A: ir.Buffer,
        B: ir.Buffer,
        C: ir.Buffer,
        accum: bool,
        prefetch: bool = False,
        **kwargs_for_extra_args,
    ) -> str:
        """
        Generate the code for calling the templated kernel that computes
        `C += alpha * A @ B` if `accum` is True, or `C = alpha * A @ B` otherwise.
        """
        # Positional arguments in DECLARE_KERNEL order: A, B, C, M, N, K, lda, ldb, ldc.
        operand_ptrs = [f"&({kernel.index(operand, [0, 0])})" for operand in (A, B, C)]
        extents = [kernel.size(C, 0), kernel.size(C, 1), kernel.size(A, 1)]
        row_strides = [kernel.stride(operand, 0) for operand in (A, B, C)]
        call_args = [*operand_ptrs, *extents, *row_strides]

        out = IndentedBuffer()
        template_args = f"{value_to_cpp(accum, 'bool')}, {value_to_cpp(prefetch, 'bool')}"
        out.writeline(f"{self.name}<{template_args}>(")
        with out.indent():
            kwargs_for_extra_args["kernel"] = kernel
            for extra_arg in self.get_kernel_extra_args(**kwargs_for_extra_args):
                out.writeline(extra_arg)
            *leading_args, final_arg = call_args
            for call_arg in leading_args:
                out.writeline(f"{call_arg},")
            out.writeline(f"{final_arg}")
        out.writeline(");")
        return out.getvalue()

    def use_local_vnni_blocking(self, should_block_weight: bool):
        """Record whether templates should VNNI-pack B locally."""
        self.pack_vnni_B_locally = should_block_weight

    def codegen_init(
        self,
        kernel: CppTemplateKernel,
    ) -> str:
        """Per-kernel setup code; the base class emits nothing."""
        return ""

    def codegen_finalize(
        self,
        kernel: CppTemplateKernel,
    ) -> str:
        """Per-kernel teardown code; the base class emits nothing."""
        return ""

    def get_b_layout(self) -> LayoutType:
        """Layout expected for B; the base class uses the plain layout."""
        return LayoutType.NORMAL

    ALLOCATE_WEIGHT_BUFFER = r"""
    {%- if is_msvc_compiler %}
    // MSVC doesn't support stack-allocated dynamic-sized arrays, so using heap memory here.
    auto heap_deq_b_buf_ptr = std::make_unique<{{buffer_dtype}}[]>({{buffer_size}});
    {{buffer_dtype}}* {{buffer_name}} = heap_deq_b_buf_ptr.get();
    {%- else %}
    // It's safe to use a stack-allocated array since the blocking strategy would
    // require us to allocate an array that's smaller than the size of L1D cache,
    // and the default per thread max stack size on Linux is quite higher,
    // so we need not worry about stack overflow.
    alignas(4096) {{buffer_dtype}} {{buffer_name}}[{{buffer_size}}];
    {%- endif %}
"""

    def codegen_allocate_weight_buffer(
        self, buffer_name: str, buffer_dtype: str, *size_args
    ) -> str:
        """
        Emit C++ that declares a weight scratch buffer named `buffer_name`.

        The element count is the product of `size_args`. MSVC gets a heap
        allocation; other compilers get an aligned stack array.
        """
        render_args = {
            "buffer_name": buffer_name,
            "buffer_dtype": buffer_dtype,
            "buffer_size": " * ".join(str(dim) for dim in size_args),
            "is_msvc_compiler": cpp_builder.is_msvc_cl(),
        }
        allocation_template = KernelTemplate._template_from_string(
            self.ALLOCATE_WEIGHT_BUFFER
        )
        return allocation_template.render(render_args)

    def is_woq_int4(self):
        """Whether this is an int4 weight-only-quantized GEMM; False by default."""
        return False


@dataclasses.dataclass
class CppMicroGemmConfig:
    """
    One registered micro-GEMM variant: the dtypes, the vec ISA and the register
    blocking it targets, plus an optional extra predicate on the problem.
    """

    # dtype of the A operand
    input_dtype: torch.dtype
    # dtype of the B operand
    input2_dtype: torch.dtype
    # dtype written to C
    output_dtype: torch.dtype
    # dtype used for the in-kernel arithmetic
    compute_dtype: torch.dtype
    # vec ISA class this variant is generated for
    vec_isa_cls: type[VecISA]
    # register blocking (block_m, block_n, block_k)
    register_blocking: GemmBlocking
    # optional predicate called as (config, m, n, k, alpha, num_threads, **kwargs);
    # presumably None means "no extra restriction" — confirm at the selection site
    extra_check: Optional[Callable[..., bool]] = dataclasses.field(default=None)


micro_gemm_configs: dict[type[CppMicroGemm], list[CppMicroGemmConfig]] = {}


def register_micro_gemm(*configs):
    """
    Return a class decorator that records `configs` for the decorated micro-GEMM
    class in `micro_gemm_configs`.

    Registering the same class twice, or registering a class with no configs, is
    rejected by an assertion.
    """

    def _register(micro_gemm_cls):
        assert micro_gemm_cls not in micro_gemm_configs, (
            f"Duplicate micro_gemm registration for {micro_gemm_cls}"
        )
        assert configs, f"No micro_gemm configs provided for {micro_gemm_cls}"
        micro_gemm_configs[micro_gemm_cls] = [*configs]
        return micro_gemm_cls

    return _register


def generate_gemm_config(
    vec_isa_cls,
    register_blockings,
    input_dtype=torch.float,
    input2_dtype=None,
    output_dtype=None,
    compute_dtype=None,
    extra_check=None,
):
    """
    Build one `CppMicroGemmConfig` per entry in `register_blockings`.

    Each blocking is unpacked into `GemmBlocking`. Missing dtypes are filled in
    as follows: `output_dtype` defaults to `input_dtype`, `compute_dtype` to the
    (resolved) `output_dtype`, and `input2_dtype` to `input_dtype`.
    """
    output_dtype = input_dtype if output_dtype is None else output_dtype
    compute_dtype = output_dtype if compute_dtype is None else compute_dtype
    input2_dtype = input_dtype if input2_dtype is None else input2_dtype

    configs = []
    for blocking in register_blockings:
        configs.append(
            CppMicroGemmConfig(
                input_dtype=input_dtype,
                input2_dtype=input2_dtype,
                output_dtype=output_dtype,
                compute_dtype=compute_dtype,
                vec_isa_cls=vec_isa_cls,
                register_blocking=GemmBlocking(*blocking),
                extra_check=extra_check,
            )
        )
    return configs


class CppMicroGemmRef(CppMicroGemm):
    """
    Naive reference micro GEMM: a plain triple loop in C++ with 1x1x1 register
    blocking, intended for correctness debugging rather than speed.
    """

    TEMPLATE_ENTRY = r"""
{{declare_kernel}} {
    for (int64_t m = 0; m < M; ++m) {
        for (int64_t n = 0; n < N; ++n) {
            {{compute_t}} result = accum ? C[m * ldc + n] : 0;
            for (int64_t k = 0; k < K; ++k) {
                result += ({{compute_t}})A[m * lda + k] * ({{compute_t}})B[k * ldb + n] * {{alpha}};
            }
            C[m * ldc + n] = result;
        }
    }
}
"""

    def __init__(
        self, name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha
    ) -> None:
        # Register blocking is irrelevant for the naive loop, so use 1x1x1.
        super().__init__(
            name,
            input_dtype,
            input2_dtype,
            output_dtype,
            compute_dtype,
            register_blocking=GemmBlocking(1, 1, 1),
            alpha=alpha,
        )

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        """Render the reference kernel definition (signature plus naive loop body)."""
        options = {"declare_kernel": self.get_kernel_declaration()}
        options.update(self.get_common_options())
        entry_template = KernelTemplate._template_from_string(self.TEMPLATE_ENTRY)
        return entry_template.render(options)


def is_int8_woq_gemm_small_m_dim_corner_case(config, m, n, k):
    """
    Return whether (m, n, k) is the small-M corner case for int8 WoQ GEMM.

    The condition is: k is a multiple of the config's block_k, n is a multiple
    of its block_n, and m is below 16.
    """
    blocking = config.register_blocking
    if k % blocking.block_k != 0:
        return False
    if n % blocking.block_n != 0:
        return False
    return m < 16


# extra check for small M dimension for int8 WoQ case
def check_int8_woq_small_m_dim(config, m, n, k, alpha, num_threads, **kwargs):
    """
    Return whether the dedicated small-M int8 WoQ blockings apply.

    The shape must be the small-M corner case and the `dynamic_M` keyword
    argument must not be truthy.
    """
    is_corner_case = is_int8_woq_gemm_small_m_dim_corner_case(config, m, n, k)
    if not is_corner_case:
        return is_corner_case
    # A dynamic M opts out of the small-M path.
    return not kwargs.get("dynamic_M", False)


# For int8 WoQ GEMM with small M, we use different blockings that shouldn't be used otherwise
def do_not_use_with_small_m_for_int8_woq(config, m, n, k, alpha, num_threads, **kwargs):
    """Exact complement of `check_int8_woq_small_m_dim` for the general int8 WoQ blockings."""
    small_m_path_applies = check_int8_woq_small_m_dim(
        config, m, n, k, alpha, num_threads, **kwargs
    )
    return not small_m_path_applies


@register_micro_gemm(
    *generate_gemm_config(
        VecAVX512,
        [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
        input_dtype=torch.float,
    ),
    *generate_gemm_config(
        VecAVX512,
        [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
        input_dtype=torch.bfloat16,
        output_dtype=torch.float,
    ),
    *generate_gemm_config(
        VecAVX512,
        [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
        input_dtype=torch.half,
        output_dtype=torch.float,
    ),
    *generate_gemm_config(
        VecAVX512,
        [(8, 48, 1), (8, 32, 1), (16, 16, 1)],
        input_dtype=torch.bfloat16,
        input2_dtype=torch.int8,
        output_dtype=torch.float,
        compute_dtype=torch.float,
        extra_check=do_not_use_with_small_m_for_int8_woq,
    ),
    *generate_gemm_config(
        VecAVX512,
        [
            (4, 32, 64),
            (8, 32, 64),
        ],
        input_dtype=torch.bfloat16,
        input2_dtype=torch.int8,
        output_dtype=torch.float,
        compute_dtype=torch.float,
        extra_check=check_int8_woq_small_m_dim,
    ),
    *generate_gemm_config(
        VecAVX2,
        [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
        input_dtype=torch.float,
    ),
    *generate_gemm_config(
        VecAVX2,
        [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
        input_dtype=torch.bfloat16,
        output_dtype=torch.float,
    ),
    *generate_gemm_config(
        VecAVX2,
        [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
        input_dtype=torch.half,
        output_dtype=torch.float,
    ),
    *generate_gemm_config(
        VecAVX2,
        [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
        input_dtype=torch.bfloat16,
        input2_dtype=torch.int8,
        output_dtype=torch.float,
        compute_dtype=torch.float,
        extra_check=do_not_use_with_small_m_for_int8_woq,
    ),
    *generate_gemm_config(
        VecAVX2,
        [
            (2, 16, 64),
            (4, 16, 64),
        ],
        input_dtype=torch.bfloat16,
        input2_dtype=torch.int8,
        output_dtype=torch.float,
        compute_dtype=torch.float,
        extra_check=check_int8_woq_small_m_dim,
    ),
    *generate_gemm_config(
        VecNEON,
        [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
        input_dtype=torch.float,
        input2_dtype=torch.float,
        output_dtype=torch.float,
        compute_dtype=torch.float,
    ),
    *generate_gemm_config(
        VecSVE256,
        [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
        input_dtype=torch.float,
        input2_dtype=torch.float,
        output_dtype=torch.float,
        compute_dtype=torch.float,
    ),
)
class CppMicroGemmFP32Vec(CppMicroGemm):
    """
    This class generates the code for micro gemm using fp32 vec instructions for compute.
    It supports input types of torch.float, torch.bfloat16, and torch.half with fp32 output.
    The B operand may also be int8 (weight-only quantization); int8 B values are
    converted to int32 and then to fp32 inside the kernel before the FMA.
    The output of the microkernel is in FP32, but it would be converted to BF16/FP16 in the template,
    if the desired output is BF16/FP16.

    Optional code-generation modes (constructor arguments):
    - `tail_n`: additionally emit an `_ntail` kernel variant that handles a last
      N block narrower than `block_n` using count-limited loads/stores.
    - `trans_b`: emit kernels that read B transposed, i.e. element (k, n) at
      `B[n * ldb + k]`. Only allowed when the picked vec ISA is AVX512 or AVX2.
    """

    TEMPLATE_ENTRY = r"""
{{declare_kernel}} {
    using Vectorized = at::vec::Vectorized<{{compute_t}}>;
    constexpr auto VLEN = Vectorized::size();
    {{kernel.assert_function}}({{block_n}} % VLEN == 0, "block_n dimension must be multiple of Vector size");
    {{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}");
    // TODO(jgong5): loop unroll for M and N
    for (int64_t m = 0; m < M; m += {{block_m}}) {
        int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
        for (int64_t n = 0; n < N; n += {{block_n}}) {
            int64_t block_n = std::min<int64_t>(N - n, {{block_n}});
            if (block_m == {{block_m}} && block_n == {{block_n}}) {
{%- if not trans_b %}
                {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum, prefetch>(
{%- else %}
                {{kernel_name}}_transpose_b_kernel<{{block_m}}, {{block_n}}, accum, prefetch>(
{%- endif %}
                    A + m * lda,
{%- if not trans_b %}
                    B + n,
{%- else %}
                    B + n * ldb,
{%- endif %}
                    C + m * ldc + n,
                    K,
                    lda,
                    ldb,
                    ldc
                );
{%- if tail_n %}
            } else if (block_n == {{block_n}}){
{%- else %}
            } else {
{%- endif %}
                switch (block_m) {
{%- for b in range(block_m - 1, 0, -1) %}
                case {{b}}:
    {%- if not trans_b %}
                    {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum, prefetch>(
    {%- else %}
                    {{kernel_name}}_transpose_b_kernel<{{b}}, {{block_n}}, accum, prefetch>(
    {%- endif %}
                        A + m * lda,
    {%- if not trans_b %}
                        B + n,
    {%- else %}
                        B + n * ldb,
    {%- endif %}
                        C + m * ldc + n,
                        K,
                        lda,
                        ldb,
                        ldc
                    );
                    break;
{%- endfor %}
                default:
                    {{kernel.assert_function}}(false, "Unsupported block_m: {{block_m}}");
                }

{%- if tail_n %}
            } else {
                switch (block_m) {
    {%- for b in range(block_m, 0, -1) %}
                case {{b}}:
        {%- if not trans_b %}
                    {{kernel_name}}_ntail_kernel<{{b}}, {{block_n}}, accum, prefetch>(
        {%- else %}
                    {{kernel_name}}_ntail_transpose_b_kernel<{{b}}, {{block_n}}, accum, prefetch>(
        {%- endif %}
                        A + m * lda,
        {%- if not trans_b %}
                        B + n,
        {%- else %}
                        B + n * ldb,
        {%- endif %}
                        C + m * ldc + n,
                        block_n,
                        K,
                        lda,
                        ldb,
                        ldc
                    );
                    break;
    {%- endfor %}
                default:
                    {{kernel.assert_function}}(false, "Unsupported block_m: {{block_m}}");
                }
            }
{%- else %}
            }
{%- endif %}
        }
    }
}
"""

    TEMPLATE_KERNEL = r"""

template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum, bool prefetch=false>
{%- if not trans_b %}
    {%- if tail_n %}
inline void {{kernel_name}}_ntail_kernel(
    {%- else %}
inline void {{kernel_name}}_kernel(
    {%- endif %}
{%- else %}
    {%- if tail_n %}
inline void {{kernel_name}}_ntail_transpose_b_kernel(
    {%- else %}
inline void {{kernel_name}}_transpose_b_kernel(
    {%- endif %}
{%- endif %}
    const {{input_t}}* {{restrict_keyword}} A,
    const {{input2_t}}* {{restrict_keyword}} B,
    {{output_t}}* {{restrict_keyword}} C,
{%- if tail_n %}
    int64_t N,
{%- endif %}
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
) {
    using Vectorized = at::vec::Vectorized<{{compute_t}}>;
{%- if input2_dtype in [torch.bfloat16, torch.float16] %}
    using VectorizedIn = at::vec::Vectorized<{{input_t}}>;
{%- endif %}

{%- if not trans_b %}
    constexpr auto VLEN = Vectorized::size();
    constexpr auto ROWS = BLOCK_M;
    constexpr auto COLS = BLOCK_N / VLEN;

    Vectorized va;
    at::vec::VectorizedN<{{compute_t}}, COLS> vb;
    at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc;

    {%- if tail_n %}
    int64_t rCOLS = (N + VLEN - 1) / VLEN;
    int ntail = N % VLEN;
    {%- endif %}
    auto loadc = [&](auto i) {
        if constexpr (accum) {
            constexpr int row = i / COLS;
            constexpr int col = i % COLS;
    {%- if tail_n %}
            int load_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN;
            if (col < rCOLS) {
                vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN, load_size);
            }
    {%- else %}
            vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN);
    {%- endif %}
        } else {
            vc[i] = Vectorized(0.0f);
        }
    };
    c10::ForcedUnroll<ROWS * COLS>{}(loadc);

    auto compute = [&, COLS](auto i, int k) {
        constexpr int row = i / COLS;
        constexpr int col = i % COLS;
    {%- if tail_n %}
        int load_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN;
    {%- endif %}
        if constexpr (col == 0) {
    {%- if alpha != 1 %}
            va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]) * {{alpha}});
    {%- else %}
            va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]));
    {%- endif %}
        }

        if constexpr (row == 0) {
    {%- if tail_n %}
            if (col < rCOLS) {
        {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
                auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, load_size);
                vb[col] = at::vec::convert<{{compute_t}}>(b);
        {%- elif input2_dtype == torch.int8 %}
            // Convert VLEN int8 elements to int32, and then fp32
                auto b32 = at::vec::convert_to_int32<int8_t>(B + k * ldb + col * VLEN, load_size);
                vb[col] = at::vec::convert<float>(b32);
        {%- else %}
                vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN, load_size);
        {%- endif %}
            } else {
                vb[col] = Vectorized(0.0f);
            }

    {%- else %}

        {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
            auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN);
            vb[col] = at::vec::convert<{{compute_t}}>(b);
        {%- elif input2_dtype == torch.int8 %}
            // Convert VLEN int8 elements to int32, and then fp32
            auto b32 = at::vec::convert_to_int32<int8_t>(B + k * ldb + col * VLEN);
            if constexpr (prefetch) {
              _mm_prefetch(B + (k + {{block_k}}) * ldb + col * VLEN, _MM_HINT_T0);
            }
            vb[col] = at::vec::convert<float>(b32);
        {%- else %}
            vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN);
        {%- endif %}
    {%- endif %}

        }

        constexpr int idx = row * COLS + col;
    {%- if tail_n %}
        if (col < rCOLS) {
            vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]);
        }
    {%- else %}
        vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]);
    {%- endif %}
    };

    for (int k = 0; k < K; ++k) {
        c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
    }

    // store to C
    auto storec = [&](auto i) {
        constexpr int row = i / COLS;
        constexpr int col = i % COLS;
    {%- if tail_n %}
        int store_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN;
        if (col < rCOLS) {
            vc[i].store(C + row * ldc + col * VLEN, store_size);
        }
    {%- else %}
        vc[i].store(C + row * ldc + col * VLEN);
    {%- endif %}
    };
    c10::ForcedUnroll<ROWS * COLS>{}(storec);

{%- else %}
    // Use 2 implementations for the transposed B:
    // First implementation:
    //   Transpose first and then perform outer product calculation in sub-blocks,
    //   which introduces an additional transpose overhead of [K, N] compared to the non-transpose version.
    // Second implementation:
    //   Directly perform inner product calculation in sub-blocks,
    //   which introduces an additional vector reduction of [M, N] compared to the non-tranpose version.
    // Therefore, when M * N / (K * N) is large, the first implementation has better performance.
    {%- if tail_n %}
    if (K % Vectorized::size() == 0 && N % Vectorized::size() == 0 && 24 * BLOCK_M > K) {
    {%- else %}
    if (K % Vectorized::size() == 0 && 24 * BLOCK_M > K) {
    {%- endif %}
        // First implementation:
        constexpr auto VLEN = Vectorized::size();
        constexpr auto ROWS = BLOCK_M;
        constexpr auto COLS = BLOCK_N / VLEN;
        int _K = K / VLEN;
        Vectorized va;
        at::vec::VectorizedN<{{compute_t}}, VLEN> vb;
        at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc;
        auto loadc = [&](auto i) {
            if constexpr (accum) {
                constexpr int row = i / COLS;
                constexpr int col = i % COLS;
                vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN);
            } else {
                vc[i] = Vectorized(0.0f);
            }
        };
        c10::ForcedUnroll<ROWS * COLS>{}(loadc);
        auto unroll_loadB = [&](auto i, const {{input2_t}}* {{restrict_keyword}} src_ptr) {
    {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
            auto b = VectorizedIn::loadu(src_ptr + i * ldb, VLEN);
            vb[i] = at::vec::convert<{{compute_t}}>(b);
    {%- elif input2_dtype == torch.int8 %}
            auto b32 = at::vec::convert_to_int32<int8_t>(src_ptr + i * ldb, VLEN);
            vb[i] = at::vec::convert<float>(b32);
    {%- else %}
            vb[i] = Vectorized::loadu(src_ptr + i * ldb, VLEN);
    {%- endif %}
        };
        auto compute_trans = [&, COLS](auto i, int k) {
            constexpr int row = i % ROWS;
            constexpr int col = i / ROWS;
            constexpr int e_col = col * VLEN;
            int idk = k * VLEN;
            if constexpr (row == 0) {
                c10::ForcedUnroll<VLEN>{}(unroll_loadB, B + e_col * ldb + idk);
                at::vec::transpose_block(vb);
            }
            constexpr int idx = row * COLS + col;
            {{kernel.unroll_pragma(16)}}
            for (int j = 0; j < VLEN; j++) {
    {%- if alpha != 1 %}
                va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + idk + j]) * {{alpha}});
    {%- else %}
                va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + idk + j]));
    {%- endif %}
                vc[idx] = at::vec::fmadd(va, vb[j], vc[idx]);
            }
        };
        for (int k = 0; k < _K; ++k) {
            c10::ForcedUnroll<ROWS * COLS>{}(compute_trans, k);
        }
        // store to C
        auto storec = [&](auto i) {
            constexpr int row = i / COLS;
            constexpr int col = i % COLS;
            vc[i].store(C + row * ldc + col * VLEN);
        };
        c10::ForcedUnroll<ROWS * COLS>{}(storec);
    } else {
        // Second implementation
    {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
        constexpr auto VLEN = VectorizedIn::size();
    {%- else %}
        constexpr auto VLEN = Vectorized::size();
    {%- endif %}
        int _K = (K + VLEN - 1) / VLEN;
        // sub-block size of BLOCK_N and BLOCK_M
        constexpr int sM = {{sub_block_m}};
        constexpr int sN = {{sub_block_n}};
    {%- if tail_n %}
        int bN = (N + sN - 1) / sN;
    {%- else %}
        constexpr int bN = (BLOCK_N + sN - 1) / sN;
    {%- endif %}
        constexpr int bM = (BLOCK_M + sM - 1) / sM;

    {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
        at::vec::VectorizedN<{{compute_t}}, 2> va;
        at::vec::VectorizedN<{{compute_t}}, 2 * sN> vb;
    {%- else %}
        at::vec::Vectorized<{{compute_t}}> va;
        at::vec::VectorizedN<{{compute_t}}, sN> vb;
    {%- endif %}
        at::vec::VectorizedN<{{compute_t}}, sN * sM> vmid;

    {%- if tail_n %}
        int ntail = N % sN;
    {%- else %}
        constexpr int ntail = BLOCK_N % sN;
    {%- endif %}
        constexpr int mtail = BLOCK_M % sM;
        int ktail = K % VLEN;

        auto compute_trans = [&](int m, int n, int k) {
    {%- if tail_n %}
            int e_n = (n == bN - 1 && ntail != 0) ? (N - n * sN) : sN;
    {%- else %}
            int e_n = (n == bN - 1 && ntail != 0) ? (BLOCK_N - n * sN) : sN;
    {%- endif %}
            int e_m = (m == bM - 1 && mtail != 0) ? (BLOCK_M - m * sM) : sM;
            int e_k = (k == _K - 1 && ktail != 0) ? (K - k * VLEN) : VLEN;
            {{kernel.unroll_pragma(sub_block_n)}}
            for (int i = 0; i < e_n; i++) {
    {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
                auto b = VectorizedIn::loadu(B + (sN * n + i) * ldb + k * VLEN, e_k);
                std::tie(vb[2 * i], vb[2 * i + 1]) = at::vec::convert_to_float<{{input_t}}>(b);
    {%- elif input2_dtype == torch.int8 %}
                auto b32 = at::vec::convert_to_int32<int8_t>(B + (sN * n + i) * ldb + k * VLEN, e_k);
                vb[i] = at::vec::convert<float>(b32);
    {%- else %}
                vb[i] = Vectorized::loadu(B + (sN * n + i) * ldb + k * VLEN, e_k);
    {%- endif %}
            }

            {{kernel.unroll_pragma(sub_block_m)}}
            for (int s = 0; s < e_m; s++) {
    {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
                auto a = VectorizedIn::loadu(A + (sM * m + s) * lda + k * VLEN, e_k);
                std::tie(va[0], va[1]) = at::vec::convert_to_float<{{input_t}}>(a);
    {%- elif input2_dtype == torch.int8 %}
                auto a32 = at::vec::convert_to_int32<int8_t>(A + (sM * m + s) * lda + k * VLEN, e_k);
                va = at::vec::convert<float>(a32);
    {%- else %}
                va = Vectorized::loadu(A + (sM * m + s) * lda + k * VLEN, e_k);
    {%- endif %}

    {%- if alpha != 1 %}
                va = va * Vectorized({{alpha}});
    {%- endif %}
                if (k == 0) {
                    {{kernel.unroll_pragma(sub_block_n)}}
                    for (int i = 0; i < e_n; i++) {
    {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
                        vmid[sN * s + i] = at::vec::fmadd(va[0], vb[2 * i], Vectorized(0.0f));
                        vmid[sN * s + i] = at::vec::fmadd(va[1], vb[2 * i + 1], vmid[sN * s + i]);
    {%- else %}
                        vmid[sN * s + i] = at::vec::fmadd(va, vb[i], Vectorized(0.0f));
    {%- endif %}
                    }
                } else {
                    {{kernel.unroll_pragma(sub_block_n)}}
                    for (int i = 0; i < e_n; i++) {
    {%- if input2_dtype in [torch.bfloat16, torch.float16] %}
                        vmid[sN * s + i] = at::vec::fmadd(va[0], vb[2 * i], vmid[sN * s + i]);
                        vmid[sN * s + i] = at::vec::fmadd(va[1], vb[2 * i + 1], vmid[sN * s + i]);
    {%- else %}
                        vmid[sN * s + i] = at::vec::fmadd(va, vb[i], vmid[sN * s + i]);
    {%- endif %}
                    }
                }
            }

            // store to C
            if (k == _K - 1) {
                {{kernel.unroll_pragma(sub_block_m)}}
                for (int s = 0; s < e_m; s++) {
                    {{kernel.unroll_pragma(sub_block_n)}}
                    for (int i = 0; i < e_n; i++) {
                        auto v = at::vec::vec_reduce_all([](Vectorized& x, Vectorized& y) { return x + y; }, vmid[sN * s + i]);
                        if constexpr (accum) {
                            auto c = *(C + (sM * m + s) * ldc + sN * n + i);
                            *(C + (sM * m + s) * ldc + sN * n + i) = c + v;
                        } else {
                            *(C + (sM * m + s) * ldc + sN * n + i) = v;
                        }
                    }
                }
            }
        };

        for (int n = 0; n < bN; ++n) {
            for (int m = 0; m < bM; ++m) {
                for (int k = 0; k < _K; ++k) {
                    compute_trans(m, n, k);
                }
            }
        }
    }
{%- endif %}
}
"""

    # set trans_b to generate gemm that supports transposed B matrix
    # set tail_n to support the tail of N
    # TODO add trans_b support for other micro gemms
    # and move setting of trans_b to the init of CppMicroGemm
    def __init__(
        self,
        name,
        input_dtype,
        input2_dtype,
        output_dtype,
        compute_dtype,
        register_blocking,
        alpha=1,
        tail_n=False,
        trans_b=False,
    ) -> None:
        super().__init__(
            name,
            input_dtype,
            input2_dtype,
            output_dtype,
            compute_dtype,
            register_blocking,
            alpha,
        )
        self.tail_n = tail_n
        # trans_b is only supported on platforms that
        # support avx512 or avx2 since transpose_block is
        # only implemented on these platforms
        # NOTE(review): in the second (inner-product) trans_b implementation of
        # TEMPLATE_KERNEL, the branch that loads A is selected by `input2_dtype`;
        # for an int8 B it reads A through `convert_to_int32<int8_t>`, while the
        # registered int8-B configs use a bfloat16 A. Confirm trans_b is never
        # combined with an int8 B before relying on that path.
        if trans_b:
            vec_isa = pick_vec_isa()
            assert isinstance(vec_isa, (VecAVX512, VecAVX2)), (
                f"trans_b micro gemm requires an AVX512 or AVX2 vec ISA, got {vec_isa}"
            )
        self.trans_b = trans_b

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        """
        Render the C++ definition of this micro GEMM.

        The result is the concatenation of three parts, in order:
        1. the full-block kernel (`*_kernel` or `*_transpose_b_kernel`);
        2. when `tail_n` is set, the N-tail variant of that kernel;
        3. the entry function, whose signature comes from `DECLARE_KERNEL`, which
           dispatches to the kernels above.
        """
        options = {
            "declare_kernel": self.get_kernel_declaration(),
            "kernel": kernel,
            "block_m": self.register_blocking.block_m,
            "block_n": self.register_blocking.block_n,
            "block_k": self.register_blocking.block_k,
            "trans_b": False,
            "tail_n": False,
            "restrict_keyword": get_restrict_keyword(),
            **self.get_common_options(),
        }
        if self.trans_b:
            # TODO supports tuning of sub_block_m/sub_block_n
            # to get better performance for specific shapes
            sub_block_m = min(1, self.register_blocking.block_m)
            sub_block_n = min(4, self.register_blocking.block_n)
            # update options to generate kernel with trans_b and sub-block size
            options.update(
                {
                    "trans_b": self.trans_b,
                    "sub_block_m": sub_block_m,
                    "sub_block_n": sub_block_n,
                }
            )
        # Parse TEMPLATE_KERNEL once; it is rendered a second time for the N tail.
        kernel_template = KernelTemplate._template_from_string(self.TEMPLATE_KERNEL)
        result = kernel_template.render(options)
        # update options to generate the kernel for the tail of N
        if self.tail_n:
            options.update(
                {
                    "tail_n": self.tail_n,
                }
            )
            result += kernel_template.render(options)
        # The entry function is rendered with the final options, so it calls the
        # `_ntail` variant only when that variant was emitted above.
        result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(
            options
        )
        return result


# extra check for CppMicroGemmAMX
def check_amx_extra(config, m, n, k, alpha, num_threads, **kwargs):
    """Extra applicability check for ``CppMicroGemmAMX`` configs.

    ``k`` must be a multiple of the VNNI packing factor used for B:
    4 for int8/uint8 inputs (VNNI4) and 2 otherwise (VNNI2).
    Only ``alpha == 1`` is supported.
    """
    if config.input_dtype in [torch.uint8, torch.int8]:
        pack = 4
    else:
        pack = 2
    k_is_packable = k % pack == 0
    return k_is_packable and alpha == 1


def check_int8_bf16_amx_extra(config, m, n, k, alpha, num_threads, **kwargs):
    """Extra check for AMX configs with bf16 activations and int8 weights.

    The int8 weights are dequantized to bf16 inside the kernel, which needs
    avx512_bf16 in addition to the generic AMX constraints checked by
    ``check_amx_extra``. ``kwargs["vec_isa"]`` must be provided.
    """
    isa = kwargs.get("vec_isa")
    assert isa is not None
    has_bf16_convert = isa.is_avx512_bf16_supported()
    return has_bf16_convert and check_amx_extra(
        config, m, n, k, alpha, num_threads, **kwargs
    )


# amx_fp16 needs to be checked separately since it is not always supported when amx is supported
def check_amx_fp16_extra(config, m, n, k, alpha, num_threads, **kwargs):
    """Extra check for AMX configs with fp16 inputs and fp32 output.

    AMX-FP16 support is queried explicitly from ``kwargs["vec_isa"]``.
    B is VNNI2-packed, so ``k`` must be even, and only ``alpha == 1`` is
    supported.
    """
    assert config.input_dtype == torch.float16
    assert config.output_dtype == torch.float
    isa = kwargs.get("vec_isa")
    assert isa is not None
    vnni_pack = 2
    return isa.is_amx_fp16_supported() and k % vnni_pack == 0 and alpha == 1


@register_micro_gemm(
    *generate_gemm_config(
        VecAMX,
        [(32, 32, 64), (48, 16, 64)],
        input_dtype=torch.int8,
        input2_dtype=torch.int8,
        output_dtype=torch.int32,
        compute_dtype=torch.int32,
        extra_check=check_amx_extra,
    ),
    *generate_gemm_config(
        VecAMX,
        [(32, 32, 32), (48, 16, 32)],
        input_dtype=torch.bfloat16,
        input2_dtype=torch.int8,
        output_dtype=torch.float,
        compute_dtype=torch.float,
        extra_check=check_int8_bf16_amx_extra,
    ),
    *generate_gemm_config(
        VecAMX,
        [(32, 16, 32), (32, 32, 32), (48, 16, 32), (16, 48, 32)],
        input_dtype=torch.bfloat16,
        output_dtype=torch.float,
        extra_check=check_amx_extra,
    ),
    *generate_gemm_config(
        VecAMX,
        [(32, 32, 32), (48, 16, 32), (16, 48, 32)],
        input_dtype=torch.float16,
        output_dtype=torch.float,
        extra_check=check_amx_fp16_extra,
    ),
    *generate_gemm_config(
        VecAMX,
        [(32, 32, 64), (48, 16, 64)],
        input_dtype=torch.uint8,
        input2_dtype=torch.int8,
        output_dtype=torch.int32,
        compute_dtype=torch.int32,
        extra_check=check_amx_extra,
    ),
)
class CppMicroGemmAMX(CppMicroGemm):
    """
    This class generates the code for micro gemm using Advanced Matrix extension (AMX)
    instructions available in 4th generation Intel Xeon for compute.

    Registered (input, weight) -> output combinations:
      - bfloat16 x bfloat16 -> fp32
      - float16 x float16 -> fp32 (requires AMX-FP16, see check_amx_fp16_extra)
      - int8/uint8 x int8 -> int32
      - bfloat16 x int8 -> fp32, where the int8 weights are dequantized to
        bfloat16 inside the kernel (requires avx512_bf16)
    """

    TEMPLATE_ENTRY = r"""
{{declare_kernel}} {
    {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
    {{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2");
{%- if pack_vnni_B_locally %}
    {{template.codegen_allocate_weight_buffer("packed_B_buf", input2_t, "K", block_n)}}
{%- endif %}
{%- if use_cached_dequantized_B %}
    // Create a stack-allocated buffer for tiles of B.
    // Except maybe for the tail-case, an AMX tile of B has 16x32 BF16 elements.
    // we cache K * {{block_n}} elements of dequantized B
    {{template.codegen_allocate_weight_buffer("dequantized_B_buf", input_t, "K", block_n)}}
    const auto buf_size = K * {{block_n}};
    auto load_dequantized_B = [&](int64_t base_idx) {
        // Load a tile of B & cache it in L1D.
        {{input2_t}}* base_addr = const_cast<{{input2_t}}*>(B) + base_idx;
        // Use 64-bit indices: idx_q advances by ldb for each K row and may not fit in an int.
        for (int64_t idx_dq = 0, idx_q = 0; idx_dq < buf_size; idx_q += ldb, idx_dq += {{block_n}}) {
            // The prefetch target does not depend on the column chunk, so issue it once per row.
            _mm_prefetch(base_addr + idx_q + 64 * ldb, _MM_HINT_T0);
            // Each 32-column chunk below gets its own scope so that the unrolled chunks
            // can reuse local names (v8, bf, ...) without redeclaration errors.
        {%- for vec_idx in range(0, block_n, 32) %}
            {
            {%- if (block_n - vec_idx) >= 32 %}
            // 1) Load 32 x int8
            __m256i v8  = _mm256_loadu_si256((const __m256i*)(base_addr + idx_q + {{vec_idx}}));
            // 2) Extract two halves
            __m128i v8_lo = _mm256_extracti128_si256(v8, 0);
            __m128i v8_hi = _mm256_extracti128_si256(v8, 1);
            // 3) Widen each half to i32
            __m512i v32_lo = _mm512_cvtepi8_epi32(v8_lo);
            __m512i v32_hi = _mm512_cvtepi8_epi32(v8_hi);
            // 4) Convert to f32
            __m512 f_lo = _mm512_cvtepi32_ps(v32_lo);
            __m512 f_hi = _mm512_cvtepi32_ps(v32_hi);
            // 5) f32 -> bf16 (round-to-nearest-even) and pack 32 lanes to 512b
            // Packs the second operand (f_lo) into the lower 16 bf16 lanes and the first (f_hi) into the upper 16.
            __m512i bf = (__m512i)_mm512_cvtne2ps_pbh(f_hi, f_lo);
            // 6) Store 32 x bf16 (512 bits)
            _mm512_storeu_si512((__m512i*)(dequantized_B_buf + idx_dq + {{vec_idx}}), bf);
            {%- elif (block_n - vec_idx) >= 16 %}
            // 1) Load 16 x int8 (128 bits)
            __m128i v8 = _mm_loadu_si128((const __m128i*)(base_addr + idx_q + {{vec_idx}}));
            // 2) Widen: 16 x i8 -> 16 x i32
            __m512i v32 = _mm512_cvtepi8_epi32(v8);
            // 3) Convert to f32
            __m512 f32 = _mm512_cvtepi32_ps(v32);
            // 4) Convert f32 -> bf16 (round-to-nearest-even)
            __m256i bf16 = (__m256i)_mm512_cvtneps_pbh(f32);
            // 5) Store 16 x bf16 (256 bits)
            _mm256_storeu_si256((__m256i*)(dequantized_B_buf + idx_dq + {{vec_idx}}), bf16);
            {%- else %}
            auto b_int8_tail = at::vec::Vectorized<int8_t>::loadu(
                base_addr + idx_q + {{block_n - (block_n % 32)}},
                static_cast<int64_t>({{block_n % 32}})
            );
            auto b_bf16_tail = at::vec::convert<{{input_t}}>(b_int8_tail);
            b_bf16_tail.store(
                dequantized_B_buf + idx_dq + {{block_n - (block_n % 32)}},
                static_cast<int64_t>({{block_n % 32}})
            );
            {%- endif %}
            }
        {%- endfor %}
        }
    };
{%- endif %}
// The ldb would not be block_n if N != block_n
{%- if use_cached_dequantized_B or pack_vnni_B_locally %}
    const int64_t updated_ldb = {{block_n}};
{%- else %}
    const int64_t updated_ldb = ldb;
{%- endif %}
    // TODO(jgong5): loop unroll for M and N
    for (int64_t n = 0; n < N; n += {{block_n}}) {
{%- if pack_vnni_B_locally %}
        // Pack non-constant weights into VNNI interleaved format in packed_B_buf
        at::vec::pack_vnni2(B + n, packed_B_buf, ldb, K, {{block_n}});
{%- elif use_cached_dequantized_B %}
        // Dequantize K * block_n int8 B elements into BF16
        load_dequantized_B(n);
{%- endif %}
        for (int64_t m = 0; m < M; m += {{block_m}}) {
            int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
            int64_t m_tail = m;
{%- for num_rows in range(block_m, 0, -16) %}
    {%- if num_rows != block_m %}
            else
    {%- endif %}
            if (block_m >= {{num_rows}}) {
                {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}<accum>(
                    amx_state,
                    A + m * lda,
{%- if use_cached_dequantized_B %}
                    dequantized_B_buf,
{%- elif pack_vnni_B_locally %}
                    packed_B_buf,
{%- else %}
                    B + n,
{%- endif %}
                    C + m * ldc + n,
                    K,
                    lda,
                    updated_ldb,
                    ldc,
                    16
                );
                block_m -= {{num_rows}};
                m_tail += {{num_rows}};
            }
{%- endfor %}
            if (block_m > 0) {
                {{kernel_name}}_amx_kernel_16_{{num_columns}}<accum>(
                    amx_state,
                    A + m_tail * lda,
{%- if use_cached_dequantized_B %}
                    dequantized_B_buf,
{%- elif pack_vnni_B_locally %}
                    packed_B_buf,
{%- else %}
                    B + n,
{%- endif %}
                    C + m_tail * ldc + n,
                    K,
                    lda,
                    updated_ldb,
                    ldc,
                    block_m
                );
            }
        }
    }
}
"""

    TEMPLATE_KERNEL = r"""

template <bool accum, bool prefetch=false>
inline void {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}(
    AMXState& amx_state,
    const {{input_t}}* {{restrict_keyword}} A,
{%- if use_cached_dequantized_B %}
    const {{input_t}}* {{restrict_keyword}} B,
{%- else %}
    const {{input2_t}}* {{restrict_keyword}} B,
{%- endif %}
    {{output_t}}* {{restrict_keyword}} C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    uint8_t tilecfg_rows
) {
    // TODO(jgong5): add prefetch hint for A, B, C
    auto loadconfig = [](const amx_tilecfg& cfg) {
        _tile_loadconfig(&cfg);
    };
    const auto last_k_offset = K / {{block_k}} * {{block_k}};
    const auto tail_k_size = K - last_k_offset;
    if C10_LIKELY (last_k_offset > 0) {
        amx_state.configure(tilecfg_rows, 64, {{num_rows}} / 16, {{num_columns}}, loadconfig);
    } else {
        amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig);
    }
    auto load_c = [&]() {
{%- for tile_row in range(num_rows // 16) %}
    {%- for tile_col in range(num_columns) %}
        {%- set tile_idx = tile_row * num_columns + tile_col %}
        _tile_loadd({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}}));
    {%- endfor %}
{%- endfor %}
    };
    auto zero_c = [&]() {
{%- for tile_row in range(num_rows // 16) %}
    {%- for tile_col in range(num_columns) %}
        {%- set tile_idx = tile_row * num_columns + tile_col %}
        _tile_zero({{tile_idx}});
    {%- endfor %}
{%- endfor %}
    };

    if constexpr (accum) {
        load_c();
    } else {
        zero_c();
    }

    auto compute = [&](int k) {
{%- set tile_offset_a = num_rows // 16 * num_columns %}
{%- set tile_offset_b = tile_offset_a + num_rows // 16 %}
{%- for tile_row in range(num_rows // 16) %}
    {%- for tile_col in range(num_columns) %}
        {%- set tile_idx_a = tile_offset_a + tile_row %}
        {%- set tile_idx_b = tile_offset_b + tile_col %}
        {%- set tile_idx_c = tile_row * num_columns + tile_col %}
        {%- if tile_col == 0 %}
        _tile_stream_loadd({{tile_idx_a}}, A + {{tile_row * 16}} * lda + k, lda * sizeof({{input_t}}));
        {%- endif %}
        {%- if tile_row == 0 %}
        _tile_loadd({{tile_idx_b}}, B + k * ldb + {{tile_col * 16 * vnni_size}}, ldb * {{vnni_size}} * sizeof({{input_t}}));
        {%- endif %}
        {%- if int8_gemm %}
            {%- if input_dtype == torch.int8 %}
        _tile_dpbssd({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
            {%- else %}
        _tile_dpbusd({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
            {%- endif %}
        {%- else %}
            {%- if input_dtype == torch.float16 %}
        _tile_dpfp16ps({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
            {%- else %}
        _tile_dpbf16ps({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
            {%- endif %}
        {%- endif %}
    {%- endfor %}
{%- endfor %}
    };

    {{kernel.unroll_pragma(4)}}
    for (int k = 0; k < last_k_offset; k += {{block_k}}) {
        compute(k);
    }

    auto store_c = [&]() {
    // store to C
{%- for tile_row in range(num_rows // 16) %}
    {%- for tile_col in range(num_columns) %}
        {%- set tile_idx = tile_row * num_columns + tile_col %}
        _tile_stored({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}}));
    {%- endfor %}
{%- endfor %}
    };

    // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
    if C10_UNLIKELY (tail_k_size > 0) {
        if C10_LIKELY (last_k_offset > 0) {
            store_c();
            amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig);
            load_c();
        }
        compute(last_k_offset);
    }

    store_c();
}
"""

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        """Render the AMX micro-kernel definitions.

        One ``TEMPLATE_KERNEL`` instance is emitted per row count
        ``block_m, block_m - 16, ..., 16``, followed by ``TEMPLATE_ENTRY``,
        which dispatches each M block to the matching instance.
        """
        block_m, block_n, block_k = self.register_blocking
        assert block_m % 16 == 0, "Only support block_m % 16 == 0 for AMX"
        assert block_n % 16 == 0, "Only support block_n % 16 == 0 for AMX"
        if self.input_dtype in [torch.uint8, torch.int8]:
            assert block_k == 64, "Only support block_k = 64 for AMX INT8"
        else:
            assert block_k == 32, "Only support block_k = 32 for AMX Bfloat16/Float16"
        # Each AMX C tile covers 16 output columns.
        num_columns = block_n // 16
        options = {
            "declare_kernel": self.get_kernel_declaration(),
            # bf16 activations with 8-bit weights: B is dequantized to bf16 first.
            "use_cached_dequantized_B": (
                self.input_dtype == torch.bfloat16
                and self.input2_dtype in [torch.int8, torch.uint8]
            ),
            "kernel": kernel,
            "block_m": block_m,
            "block_n": block_n,
            "block_k": block_k,
            "num_columns": num_columns,
            "restrict_keyword": get_restrict_keyword(),
            **self.get_common_options(),
        }
        result = ""
        for num_rows in range(block_m, 0, -16):
            amx_kernel_options = {**options, "num_rows": num_rows}
            result += KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render(
                amx_kernel_options
            )
        result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(
            options
        )
        return result

    def codegen_init(
        self,
        kernel: CppTemplateKernel,
    ) -> str:
        """Declare the ``AMXState`` object that the generated kernels configure tiles through."""
        return "AMXState amx_state;"

    def codegen_finalize(
        self,
        kernel: CppTemplateKernel,
    ) -> str:
        """Release the AMX tile state with ``_tile_release()``."""
        return "amx_state.release([]() { _tile_release(); });"

    def get_kernel_extra_args_declare(self) -> str:
        """The entry function takes ``amx_state`` as an extra argument."""
        return "AMXState& amx_state,"

    def get_kernel_extra_args(self, **kwargs) -> list[str]:
        """Pass ``amx_state`` to the entry function."""
        return ["amx_state,"]

    def get_b_layout(self):
        """Return VNNI4 for 8-bit inputs and VNNI2 otherwise."""
        if self.input_dtype in [torch.uint8, torch.int8]:
            return LayoutType.VNNI4
        else:
            return LayoutType.VNNI2


# extra check for CppMicroBrgemm
def check_brgemm_extra(config, m, n, k, alpha, num_threads, **kwargs):
    assert config.input_dtype == torch.half and config.output_dtype == torch.float
    vnni_size = 2
    # use brgemm for Half when amx_fp16 is supported
    return torch.cpu._is_amx_fp16_supported() and k % vnni_size == 0 and alpha == 1


@register_micro_gemm(
    *generate_gemm_config(
        VecAMX,
        [(32, 32, 32), (48, 16, 32), (16, 48, 32)],
        input_dtype=torch.half,
        output_dtype=torch.float,
        extra_check=check_brgemm_extra,
    ),
)
class CppMicroBrgemm(CppMicroGemm):
    """
    Micro GEMM that delegates to oneDNN brgemm via
    ``at::native::cpublas::brgemm``. Supports torch.half inputs
    (fp32 output).
    """

    TEMPLATE_ENTRY = r"""
#include <ATen/native/CPUBlas.h>
{{declare_kernel}} {
{%- if pack_vnni_B_locally %}
    {{template.codegen_allocate_weight_buffer("packed_B_buf", input2_t, "K * N")}}
    at::vec::pack_vnni2(B, packed_B_buf, ldb, K, N);
{%- endif %}
    at::native::cpublas::brgemm(
      M, N, K,
    {%- if pack_vnni_B_locally %}
      lda, N, ldc,
    {%- else %}
      lda, ldb, ldc,
    {%- endif %}
      accum,
      A,
    {%- if pack_vnni_B_locally %}
      packed_B_buf,
    {%- else %}
      B,
    {%- endif %}
      C);
}
"""

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        """Render ``TEMPLATE_ENTRY``; brgemm needs no separate inner kernel."""
        rb = self.register_blocking
        render_options = {
            "declare_kernel": self.get_kernel_declaration(),
            "kernel": kernel,
            "block_m": rb.block_m,
            "block_n": rb.block_n,
            "block_k": rb.block_k,
            "restrict_keyword": get_restrict_keyword(),
            **self.get_common_options(),
        }
        entry_template = KernelTemplate._template_from_string(self.TEMPLATE_ENTRY)
        return entry_template.render(render_options)

    def codegen_finalize(
        self,
        kernel: CppTemplateKernel,
    ) -> str:
        """Emit the brgemm release call."""
        return "at::native::cpublas::brgemm_release();"

    def get_b_layout(self):
        """B is VNNI2-packed; only valid for Half inputs on AMX-FP16 CPUs."""
        assert self.input_dtype == torch.half
        assert torch.cpu._is_amx_fp16_supported()
        return LayoutType.VNNI2


def check_woq_int4_extra(config, m, n, k, alpha, num_threads, **kwargs):
    """Extra check for WoQ int4 micro-GEMM configs.

    Requirements:
      - ``alpha == 1`` (checked first, before ``q_group_size`` is read);
      - ``kwargs["q_group_size"]`` is present and one of 32/64/128;
      - ``k`` is a multiple of the group size, and ``block_k`` is no larger than it;
      - ``k`` is a multiple of ``block_k`` and ``n`` is a multiple of 64.
    """
    if alpha != 1:
        return False
    group = kwargs.get("q_group_size")
    assert group is not None
    if not (group in [32, 64, 128] and k % group == 0):
        return False
    block_k = config.register_blocking.block_k
    if block_k > group:
        return False
    return k % block_k == 0 and n % 64 == 0


@register_micro_gemm(
    # TODO: support float/half input
    *generate_gemm_config(
        VecAVX512,
        [(4, 64, 32), (4, 64, 64), (4, 64, 128)],
        input_dtype=torch.bfloat16,
        input2_dtype=torch.uint8,
        output_dtype=torch.float,
        compute_dtype=torch.float,
        extra_check=check_woq_int4_extra,
    ),
)
class CppMicroGemmWoQInt4Avx512(CppMicroGemmFP32Vec):
    """
    This class generates the code for WoQ int4 micro gemm using AVX512 intrinsics.
    It is based on the corresponding ATen kernel.
    Shape of packed weight = [N // 64, K, 32], viewed as [N, K // 2]
    Shape of packed ScalesAndZeros = [K // group_size, N, 2]
    """

    TEMPLATE_ENTRY = r"""
{{declare_kernel}} {
    {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
    {{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}");
    auto group_size = q_group_size;
    for (int64_t m = 0; m < M; m += {{block_m}}) {
        int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
        for (int64_t n = 0; n < N; n += {{block_n}}) {
            if (block_m == {{block_m}}) {
                {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum>(
                    A + m * lda,
                    reinterpret_cast<const uint8_t*>(B) + n * ldb,
                    C + m * ldc + n,
                    K,
                    lda,
                    /* ldb */ {{block_n}} / 2,
                    ldc,
                    group_size,
                    ScaleAndZeros + n * 2,
                    lds,
                    k_start
                );
            } else {
                switch (block_m) {
                {%- for b in range(block_m - 1, 0, -1) %}
                case {{b}}:
                    {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>(
                        A + m * lda,
                        reinterpret_cast<const uint8_t*>(B) + n * ldb,
                        C + m * ldc + n,
                        K,
                        lda,
                        /* ldb */ {{block_n}} / 2,
                        ldc,
                        group_size,
                        ScaleAndZeros + n * 2,
                        lds,
                        k_start
                    );
                    break;
                {%- endfor %}
                default:
                    {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m);
                }
            }
        }
    }
}
"""

    TEMPLATE_KERNEL = r"""
inline bool {{kernel_name}}_is_block_start(int index, int k_start, int group_size) {
  return (k_start + index) % group_size == 0;
}

inline __m128i {{kernel_name}}_convert_int4_to_int8(const uint8_t* data) {
  __m128i tmp = _mm_loadu_si64((const __m128i*)data);
  __m128i bytes = _mm_cvtepu8_epi16(tmp);
  const __m128i lowMask = _mm_set1_epi8(0xF);
  __m128i high = _mm_andnot_si128(lowMask, bytes);
  __m128i low = _mm_and_si128(lowMask, bytes);
  high = _mm_slli_epi16(high, 4);
  bytes = _mm_or_si128(low, high);
  return bytes;
}

template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum>
inline void {{kernel_name}}_kernel(
    const {{input_t}}* {{restrict_keyword}} A,
    const uint8_t* {{restrict_keyword}} B,
    {{output_t}}* {{restrict_keyword}} C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    int64_t q_group_size,
    const at::BFloat16* {{restrict_keyword}} ScaleAndZeros,
    int64_t lds, // leading dimension of ScaleAndZeros
    int64_t k_start) {
  constexpr int BLOCK_K = {{block_k}};
  constexpr int ROWS = BLOCK_M;
  constexpr int COLS = BLOCK_N / 16;

  const int PREFETCH_SIZE_K = 16 * 4;
  const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K;

  // number of blocks on K
  const int KB = K / BLOCK_K;

  __m512 va;
  __m512 vb[COLS];
  __m512 vc[ROWS * COLS];
  __m512 scale[COLS];
  __m512 zero[COLS];

  // Lookup table to de-quantize int4 values to bf16.
  // Values are dequantized as truly int4 [-8, 7] range;
  //
  // dequant = (bf16(int4_value) * bf16_scale) + bf16_zero
  //
  static const __m512 lut = _mm512_set_ps(
      7.0f, 6.0f, 5.0f, 4.0f,
      3.0f, 2.0f, 1.0f, 0.0f,
      -1.0f, -2.0f, -3.0f, -4.0f,
      -5.0f, -6.0f, -7.0f, -8.0f);

  // index for transpose
  static const __m512i idx1 = _mm512_set_epi32(
      30, 28, 26, 24, 22, 20, 18, 16,
      14, 12, 10, 8, 6, 4, 2, 0);
  static const __m512i idx2 = _mm512_set_epi32(
      31, 29, 27, 25, 23, 21, 19, 17,
      15, 13, 11, 9, 7, 5, 3, 1);

  // load scale and zero point
  auto load_scale_and_zeros = [&](int i, int _kb) {
    // load 2x bfloat16 vector
    __m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * lds + 32 * i));
    _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * lds + 32 * i, _MM_HINT_T0);

    // convert to 2x f32 vector
    __m512 a, b;
    at::vec::cvtbf16_fp32(t, a, b);

    // transpose scale_and_zero from {16, 2} to {2, 16}
    // inputs:
    //   a: {s0, z0, s1, z1, ..., s7, z7}
    //   b: {s8, z8, s9, z9, ..., s15, z15}
    // output:
    //   scale: {s0, s1, s2, ..., s15}
    //   zero:  {z0, z1, z2, ..., z15}
    scale[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b);
    zero[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b);
  };

  auto loadc = [&](auto i) {
    if constexpr (accum) {
       constexpr int row = i / COLS;
       constexpr int col = i % COLS;
       vc[i] = _mm512_loadu_ps(C + row * ldc + col * 16);
    } else {
      vc[i] = _mm512_setzero_ps();
    }
  };
  c10::ForcedUnroll<ROWS * COLS>{}(loadc);

  auto compute = [&, COLS](auto i, int k) {
    constexpr  int row = i / COLS;
    constexpr  int col = i % COLS;

    if constexpr (col == 0) {
      float aa = static_cast<float>(A[row * lda + k]);
      _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      va = _mm512_set1_ps(aa);
    }

    if constexpr (row == 0) {
      if constexpr (COLS == 4) {
        // when BLOCK_N = 64, handle each row at a time
        // to reduce de-quantize overhead.
        if constexpr (col == 0) {
          __m256i b4 = _mm256_loadu_si256((__m256i*)(B + k * ldb));
          _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb, _MM_HINT_T0);

          __m512i b32 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(b4));
          vb[0] = _mm512_permutexvar_ps(b32, lut);
          vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]);
          vb[2] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
          vb[2] = _mm512_fmadd_ps(vb[2], scale[2], zero[2]);

          b32 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(b4, 1));
          vb[1] = _mm512_permutexvar_ps(b32, lut);
          vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]);
          vb[3] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
          vb[3] = _mm512_fmadd_ps(vb[3], scale[3], zero[3]);
        }
      } else {
        __m128i b8 = {{kernel_name}}_convert_int4_to_int8(B + k * ldb + col * 8);
        __m512i b32 = _mm512_cvtepu8_epi32(b8);
        vb[col] = _mm512_permutexvar_ps(b32, lut);
        vb[col] = _mm512_fmadd_ps(vb[col], scale[col], zero[col]);
      }
    }

    constexpr int idx = row * COLS + col;
    vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
  };

  for (int k = 0, kb = 0; k < K; ++k) {
    // block_k may be smaller than q_group_size, so k_start is not necessarily
    // group-aligned. Always load the first group's scales/zeros at k == 0 so
    // scale/zero are never read uninitialized (same policy as the AMX WoQ kernel).
    if ({{kernel_name}}_is_block_start(k, k_start, q_group_size) || k == 0) {
      c10::ForcedUnroll<COLS>{}(load_scale_and_zeros, kb++);
    }
    c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
  }

  //store to C
  auto storec = [&, COLS](auto i) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;
    _mm512_storeu_ps(C + row * ldc + col * 16, vc[i]);
  };
  c10::ForcedUnroll<ROWS * COLS>{}(storec);
}
"""

    def get_kernel_extra_args_declare(self) -> str:
        """Extra entry-function parameters for the group-wise quantization data.

        The restrict qualifier comes from ``get_restrict_keyword()`` so that the
        declaration also compiles with MSVC (``__restrict``).
        """
        restrict = get_restrict_keyword()
        return (
            "const int64_t q_group_size,\n"
            f"    const at::BFloat16* {restrict} ScaleAndZeros,\n"
            "    const int64_t lds,\n"
            "    int64_t k_start,"
        )

    def get_kernel_extra_args(self, **kwargs) -> list[str]:
        """Call-site arguments that match ``get_kernel_extra_args_declare``.

        Requires ``kernel`` and ``qscale_and_zeros`` in ``kwargs``. ``lds`` is
        passed as ``N * 2`` (scale and zero point interleaved per column).
        """
        assert "kernel" in kwargs
        assert "qscale_and_zeros" in kwargs
        kernel = kwargs["kernel"]
        qscale_and_zeros = kwargs["qscale_and_zeros"]
        return [
            "group_size,",
            f"&({kernel.index(qscale_and_zeros, [0, 0, 0])}),",
            "N * 2,",  # lds
            "k_start,",
        ]

    def is_woq_int4(self):
        return True


@register_micro_gemm(
    *generate_gemm_config(
        VecAMX,
        [  # (block_m, block_n, block_k)
            (16, 32, 32),
            (32, 32, 32),
        ],
        input_dtype=torch.bfloat16,
        input2_dtype=torch.uint8,
        output_dtype=torch.float,
        compute_dtype=torch.float,
        extra_check=check_amx_extra,
    ),
)
class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX):
    """
    This class generates the code for WoQ int4 micro gemm using AMX intrinsics,
    which are available on 4th and newer generations of Intel Xeon.
    Shape of packed weight = [N // 32, K, 16], viewed as [N, K // 2]
    Shape of packed ScalesAndZeros = [K // group_size, N, 2]
    Reuse TEMPLATE_KERNEL of CppMicroGemmAMX.

    The dequantized B buffer holds a single block_n-wide panel (K * block_n
    bf16 values). It is refilled for every n step, so the AMX kernels always
    read it from its start.
    """

    TEMPLATE_ENTRY = r"""
inline bool {{kernel_name}}_is_block_start(int index, int k_start, int group_size) {
  // check if (k_start + index) % group_size == 0, assuming group_size = 32/64/128
  return ((k_start + index) & (group_size - 1)) == 0;
}

{{declare_kernel}} {
    {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
    {{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2");
    {{kernel.assert_function}}({{block_n}} == 32, "block_n must be 32 for WOQ int4");

    // Create a stack-allocated buffer for tiles of B.
    // Except maybe for the tail-case, an AMX tile of B has 16x32 BF16 elements.
    // we cache K * {{block_n}} elements of dequantized B
    {{template.codegen_allocate_weight_buffer("dequantized_B_buf", input_t, "K", block_n)}}

    constexpr int BLOCK_K = {{block_k}};
    constexpr int64_t BLOCK_N = {{block_n}};
    constexpr int COLS = BLOCK_N / 16;
    const int PREFETCH_SIZE_K = 16 * 4;
    const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K;
    const int KB = K / BLOCK_K;

    __m512i b32[COLS * 2];
    __m512 vb[COLS * 2];
    __m512 scale[COLS];
    __m512 zero[COLS];

    // Lookup table to de-quantize int4 values to bf16.
    // Values are dequantized as truly int4 [-8, 7] range;
    //
    // dequant = (bf16(int4_value) * bf16_scale) + bf16_zero
    //
    static const __m512 lut = _mm512_set_ps(
        7.0f, 6.0f, 5.0f, 4.0f,
        3.0f, 2.0f, 1.0f, 0.0f,
        -1.0f, -2.0f, -3.0f, -4.0f,
        -5.0f, -6.0f, -7.0f, -8.0f);

    // index for transpose
    static const __m512i idx1 = _mm512_set_epi32(
        30, 28, 26, 24, 22, 20, 18, 16,
        14, 12, 10, 8, 6, 4, 2, 0);
    static const __m512i idx2 = _mm512_set_epi32(
        31, 29, 27, 25, 23, 21, 19, 17,
        15, 13, 11, 9, 7, 5, 3, 1);

    // Indices for VNNI layout conversion
    __m512i idx_low = _mm512_set_epi32(
        0x17,
        0x07,
        0x16,
        0x06,
        0x15,
        0x05,
        0x14,
        0x04,
        0x13,
        0x03,
        0x12,
        0x02,
        0x11,
        0x01,
        0x10,
        0x00);
    __m512i idx_high = _mm512_set_epi32(
        0x1f,
        0x0f,
        0x1e,
        0x0e,
        0x1d,
        0x0d,
        0x1c,
        0x0c,
        0x1b,
        0x0b,
        0x1a,
        0x0a,
        0x19,
        0x09,
        0x18,
        0x08);

    // load scale and zero point
    auto load_scale_and_zeros = [&](int i, int _kb) {
        // load 2x bfloat16 vector
        __m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * lds + 32 * i));
        _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * lds + 32 * i, _MM_HINT_T0);

        // convert to 2x f32 vector
        __m512 a, b;
        at::vec::cvtbf16_fp32(t, a, b);

        // transpose scale_and_zero from {16, 2} to {2, 16}
        // inputs:
        //   a: {s0, z0, s1, z1, ..., s7, z7}
        //   b: {s8, z8, s9, z9, ..., s15, z15}
        // output:
        //   scale: {s0, s1, s2, ..., s15}
        //   zero:  {z0, z1, z2, ..., z15}
        scale[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b);
        zero[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b);
    };

    // Dequantize a B block of 2 * block_n into bf16
    // So, it handles k and k+1 at the same time
    auto dequantize_B = [&](int n) {
        constexpr int64_t ldb_int4 = BLOCK_N / 2; // 16
        // Packed int4 panel for columns [n, n + BLOCK_N): K rows of ldb_int4 bytes.
        const uint8_t* B_panel = B + n / 2 * K;
        for (int k = 0, kb = 0; k < K; k += 2) {
            // Since block_k must be 32 for AMX microkernels, k_start may not be
            // a multiple of q_group_size. In that case, we need to load scales
            // and zero points immediately when k == 0 here
            if ({{kernel_name}}_is_block_start(k, k_start, q_group_size) || k == 0) {
                c10::ForcedUnroll<COLS>{}(load_scale_and_zeros, kb++);
            }

            // Prefetch within the current panel (the panel offset must be included).
            _mm_prefetch(B_panel + (k + PREFETCH_SIZE_K) * ldb_int4, _MM_HINT_T0);

            // load 128 bits = 32 elements in int4 (row k)
            __m128i b4 = _mm_loadu_si128((__m128i*)(B_panel + k * ldb_int4));
            b32[0] = _mm512_cvtepu8_epi32(b4);
            b32[1] = _mm512_srli_epi32(b32[0], 4);
            vb[0] = _mm512_permutexvar_ps(b32[0] , lut);
            vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]);
            vb[1] = _mm512_permutexvar_ps(b32[1], lut);
            vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]);

            // load 128 bits = 32 elements in int4 (row k + 1)
            __m128i b4_2 = _mm_loadu_si128((__m128i*)(B_panel + (k + 1) * ldb_int4));
            b32[0 + COLS] = _mm512_cvtepu8_epi32(b4_2);
            b32[1 + COLS] = _mm512_srli_epi32(b32[0 + COLS], 4);
            vb[0 + COLS] = _mm512_permutexvar_ps(b32[0 + COLS] , lut);
            vb[0 + COLS] = _mm512_fmadd_ps(vb[0 + COLS], scale[0], zero[0]);
            vb[1 + COLS] = _mm512_permutexvar_ps(b32[1 + COLS], lut);
            vb[1 + COLS] = _mm512_fmadd_ps(vb[1 + COLS], scale[1], zero[1]);

            for (int i = 0; i < COLS; i++) {
                // convert to VNNI
                auto low = _mm512_permutex2var_ps(vb[i], idx_low, vb[i + COLS]);
                auto high = _mm512_permutex2var_ps(vb[i], idx_high, vb[i + COLS]);
                // convert lower 16 float32 values to bfloat16
                auto v0_bf16 = reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(low));
                // convert higher 16 float32 values to bfloat16
                auto v1_bf16 = reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(high));
                // combine the lower 16 and higher 16 bfloat16 values
                auto v = _mm512_castsi256_si512(v0_bf16);
                v = _mm512_inserti64x4(v, v1_bf16, 1);
                // store the VNNI format bfloat16 values
                {{input_t}}* addr = dequantized_B_buf + k * 32 + (i % 2) * 32;
                _mm512_storeu_si512(addr, v);
            }
        }
    };

    for (int64_t n = 0; n < N; n += {{block_n}}) {
        // Dequantize K * block_n int4 B elements into BF16. The buffer only holds
        // this single panel, so the kernels below read it from its start.
        dequantize_B(n);
        for (int64_t m = 0; m < M; m += {{block_m}}) {
            int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
            int64_t m_tail = m;
        {%- for num_rows in range(block_m, 0, -16) %}
            {%- if num_rows != block_m %}
            else
        {%- endif %}
            if (block_m >= {{num_rows}}) {
                {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}<accum>(
                    amx_state,
                    A + m * lda,
                    dequantized_B_buf,
                    C + m * ldc + n,
                    K,
                    lda,
                    {{block_n}},
                    ldc,
                    16
                );
                block_m -= {{num_rows}};
                m_tail += {{num_rows}};
            }
        {%- endfor %}
            if (block_m > 0) {
                {{kernel_name}}_amx_kernel_16_{{num_columns}}<accum>(
                    amx_state,
                    A + m_tail * lda,
                    dequantized_B_buf,
                    C + m_tail * ldc + n,
                    K,
                    lda,
                    {{block_n}},
                    ldc,
                    block_m
                );
            }
        } // for m
    } // for n
}
"""

    def get_kernel_extra_args_declare(self) -> str:
        """Extra entry-function parameters: AMX state plus group-wise quantization data.

        The restrict qualifier comes from ``get_restrict_keyword()`` so that the
        declaration also compiles with MSVC (``__restrict``).
        """
        restrict = get_restrict_keyword()
        return (
            "AMXState& amx_state,\n"
            "    const int64_t q_group_size,\n"
            f"    const c10::BFloat16* {restrict} ScaleAndZeros,\n"
            "    const int64_t lds,\n"
            "    int64_t k_start,"
        )

    def get_kernel_extra_args(self, **kwargs) -> list[str]:
        """Call-site arguments that match ``get_kernel_extra_args_declare``.

        Requires ``kernel`` and ``qscale_and_zeros`` in ``kwargs``.
        """
        assert "kernel" in kwargs
        assert "qscale_and_zeros" in kwargs
        kernel = kwargs["kernel"]
        qscale_and_zeros = kwargs["qscale_and_zeros"]
        return [
            "amx_state,",
            "group_size,",
            f"&({kernel.index(qscale_and_zeros, [0, 0, 0])}),",
            "N * 2,",  # lds
            "k_start,",
        ]

    def is_woq_int4(self):
        return True


def create_micro_gemm(
    name,
    m,
    n,
    k,
    input_dtype,
    input2_dtype,
    output_dtype=None,
    compute_dtype=None,
    alpha=1,
    num_threads=-1,
    use_ref=True,
    q_group_size=None,
) -> Optional[CppMicroGemm]:
    """
    Based on the provided info, try to find the config of the micro-kernel that would
    deliver the best performance in terms of lower latency for this case.

    A config from ``micro_gemm_configs`` is a candidate when all of these hold:

    - ``pick_vec_isa()`` is a subclass of the config's ISA class;
    - its input, input2, output and compute dtypes match exactly;
    - its optional ``extra_check`` accepts the problem;
    - for AMX configs, the WoQ small-``m`` skip rule does not apply.

    The candidate with the largest ranking key is instantiated. Ties go to the
    first such candidate in iteration order.

    Args:
        name: kernel name forwarded to the micro-GEMM constructor.
        m: GEMM M size. If it has free symbols, the problem is treated as
            dynamic and its size hint (fallback 1) is used for scoring.
        n: GEMM N size; must be an int or a numeric (non-symbolic) value.
        k: GEMM K size; must be an int or a numeric (non-symbolic) value.
        input_dtype: dtype of the first operand.
        input2_dtype: dtype of the second operand.
        output_dtype: micro-kernel output dtype; defaults to ``input_dtype``.
        compute_dtype: compute dtype; defaults to ``output_dtype``.
        alpha: forwarded unchanged to ``extra_check`` and to the constructor.
        num_threads: thread count used for occupancy scoring. A negative value
            means ``parallel_num_threads()``.
        use_ref: when no config matches, return a ``CppMicroGemmRef`` if True,
            otherwise ``None``.
        q_group_size: forwarded to each config's ``extra_check``.

    Returns:
        An instance of the best-ranked micro-GEMM class, the reference
        implementation, or ``None``.
    """

    def create_from_config(cls, config: CppMicroGemmConfig):
        # Instantiate the chosen class with the dtypes/blocking of its config.
        return cls(
            name,
            config.input_dtype,
            config.input2_dtype,
            config.output_dtype,
            config.compute_dtype,
            config.register_blocking,
            alpha,
        )

    def skip_amx_kernel_for_woq(dynamic_M):
        # For WoQ GEMM, AMX micro-kernel may not perform well if m is small.
        # Exception: for dynamic shapes, we consider using the AMX micro-kernel.
        if (
            dynamic_M
            or input_dtype != torch.bfloat16
            or input2_dtype not in [torch.int8, torch.uint8]
        ):
            return False
        m_threshold = 5
        return m < m_threshold

    def rank_config(config):
        # Ranking key of an accepted config; tuples compare lexicographically
        # and the largest one wins. Criteria, in priority order:
        # 1. ISA: AMX > VEC
        # 2. Dividable by block sizes (block_m, block_n, block_k)
        # 3. Number of mxn blocks is large enough to occupy all the threads
        # 4. Register blocks are larger
        block_m, block_n, block_k = config.register_blocking
        isa_score = 1 if config.vec_isa_cls == VecAMX else 0
        dividable_score = 0
        for size, block in ((m, block_m), (n, block_n), (k, block_k)):
            if size % block == 0:
                dividable_score += 1
        n_blocks = (n + block_n - 1) // block_n
        total_mxn_blocks = n_blocks * ((m + block_m - 1) // block_m)
        occupancy_score = 0
        if n_blocks >= num_threads:
            occupancy_score += 1
        if total_mxn_blocks >= num_threads:
            occupancy_score += 1
        register_bytes = (
            block_m * block_n * config.compute_dtype.itemsize
            + (block_m * block_k + block_k * block_n) * config.input_dtype.itemsize
        )
        # if number of mxn blocks can not occupy all the threads,
        # we favor smaller register blocks.
        size_score = register_bytes if occupancy_score != 0 else 0 - register_bytes
        return (isa_score, dividable_score, occupancy_score, size_score)

    assert isinstance(n, int) or n.is_number, n
    assert isinstance(k, int) or k.is_number, k
    from ..utils import has_free_symbols

    dynamic_M = has_free_symbols((m,))
    m = V.graph.sizevars.size_hint(m, fallback=1) if dynamic_M else m
    assert isinstance(m, int) or m.is_number, m
    if output_dtype is None:
        output_dtype = input_dtype
    if compute_dtype is None:
        compute_dtype = output_dtype
    if num_threads < 0:
        num_threads = parallel_num_threads()
    vec_isa = pick_vec_isa()
    # The WoQ AMX skip decision depends only on m and the operand dtypes, not
    # on the individual config, so evaluate it once instead of per config.
    skip_amx = skip_amx_kernel_for_woq(dynamic_M)
    matched_configs = []
    for cls, configs in micro_gemm_configs.items():
        for config in configs:
            if not issubclass(vec_isa.__class__, config.vec_isa_cls):
                continue
            # The output_dtype here is the output dtype of the micro-kernel.
            # In some cases, the actual output dtype of the op for which the micro-kernel
            # is being created would be same as that of the activation, but the micro-kernels
            # compute output in Float/int32, which is converted in the GEMM template. This is
            # subject to change in the future.
            if not (
                config.input_dtype == input_dtype
                and config.compute_dtype == compute_dtype
                and config.input2_dtype == input2_dtype
                and config.output_dtype == output_dtype
            ):
                continue
            if config.extra_check is not None and not config.extra_check(
                config,
                m,
                n,
                k,
                alpha,
                num_threads,
                dynamic_M=dynamic_M,
                q_group_size=q_group_size,
                vec_isa=vec_isa,
            ):
                continue
            if config.vec_isa_cls == VecAMX and skip_amx:
                continue
            matched_configs.append((rank_config(config), cls, config))
    if not matched_configs:
        if use_ref:
            return CppMicroGemmRef(
                name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha
            )
        return None
    # TODO(jgong5): allow autotuning on choices of configs
    # max() returns the first maximal entry, so ties keep iteration order.
    _, best_cls, best_config = max(matched_configs, key=operator.itemgetter(0))
    return create_from_config(best_cls, best_config)
